# -*- coding: utf-8 -*-
"""
Created on Wed Nov 25 12:17:51 2020

@author: sli85
"""
import scipy.io
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import pandas as pd

# Plot the DNN surge prediction against the high-fidelity simulation stored in
# 'surge_prediction_494.mat' (keys 'predict' and 'true').
# NOTE(review): the variable is named ..._469 but loads file ..._494; the name is
# kept unchanged in case later code refers to it.
surge_pred_469 = scipy.io.loadmat('surge_prediction_494.mat')
pred_surge = surge_pred_469['predict']
true_surge = surge_pred_469['true']

plt.figure()
# One sample per hour starting at t = 0 h. The time axis is derived from each
# series' length instead of hard-coding 97 points, so other lengths still plot.
plt.plot(np.arange(pred_surge.size), pred_surge.ravel(), label='DNN prediction', lw=2)
plt.plot(np.arange(true_surge.size), true_surge.ravel(),
         label='High-fidelity simulation', ls='--', lw=2)
plt.xlim((0, 100))
plt.ylim((-0.5, 2.5))
plt.xticks(fontsize=14)
plt.yticks(fontsize=14)
plt.xlabel('time (h)', fontsize=14)
plt.ylabel('surge elevation (m)', fontsize=14)
plt.legend(fontsize=12)
plt.show()


# Plot the DNN significant-wave-height prediction against the high-fidelity
# simulation stored in 'wave_prediction_494.mat' (keys 'predict' and 'true').
# NOTE(review): the variable is named ..._469 but loads file ..._494; the name is
# kept unchanged in case later code refers to it.
wave_pred_469 = scipy.io.loadmat('wave_prediction_494.mat')
pred_wave = wave_pred_469['predict']
true_wave = wave_pred_469['true']

plt.figure()
# One sample per hour starting at t = 0 h. The time axis is derived from each
# series' length instead of hard-coding 97 points, so other lengths still plot.
plt.plot(np.arange(pred_wave.size), pred_wave.ravel(), label='DNN prediction', lw=2)
plt.plot(np.arange(true_wave.size), true_wave.ravel(),
         label='High-fidelity simulation', ls='--', lw=2)
plt.xlim((0, 100))
plt.ylim((0, 2.5))
plt.xticks(fontsize=14)
plt.yticks(fontsize=14)
plt.xlabel('time (h)', fontsize=14)
plt.ylabel('significant wave height (m)', fontsize=14)
plt.legend(fontsize=12)
plt.show()






# Load the surge and wave prediction samples (first row of each stored matrix is
# what gets plotted and used below).
a = scipy.io.loadmat('surge_prediction.mat')
surge_prediction = a['surge_prediction']
b = scipy.io.loadmat('wave_prediction.mat')
wave_prediction = b['wave_prediction']

# Shared 0.25 m bin edges from 0 to 3 m. np.histogram-style binning ignores samples
# outside [0, 3], so e.g. negative surge values do not appear in the surge plot.
hist_edges = np.arange(0, 3.25, step=0.25)

for samples, x_label in (
    (surge_prediction[0, :], 'Surge elevation (m)'),
    (wave_prediction[0, :], 'Significant wave height (m)'),
):
    plt.figure(dpi=200)
    plt.hist(samples, bins=hist_edges)
    plt.xticks(hist_edges, fontsize=10)
    plt.yticks(fontsize=10)
    plt.xlabel(x_label, fontsize=10)
    plt.ylabel('frequency', fontsize=10)
    plt.show()



# Joint 2-D histogram of (surge elevation, significant wave height) samples,
# drawn as 3-D bars.
# NOTE: the original author reported rendering glitches here and attributed them
# to Matplotlib rather than to this code.
fig = plt.figure(figsize=(10, 8), dpi=200)
ax = fig.add_subplot(projection='3d')
xedges = np.arange(0, 3.25, step=0.25)
yedges = np.arange(0, 3.25, step=0.25)

hist, xedges, yedges = np.histogram2d(surge_prediction[0, :], wave_prediction[0, :],
                                      bins=(xedges, yedges))

# Anchor (lower-left corner) of one bar per 2-D bin. With 13 edges per axis this
# is 12 x 12 = 144 bars. "ij" indexing keeps xpos/ypos aligned with hist[i, j].
xpos, ypos = np.meshgrid(xedges[:-1], yedges[:-1], indexing="ij")
xpos = xpos.ravel()
ypos = ypos.ravel()
zpos = 0

# Bar footprints equal the bin widths, laid out in the same "ij" order as the
# anchors. Widths are derived from the edges instead of repeating the bin step.
dx, dy = np.meshgrid(np.diff(xedges), np.diff(yedges), indexing="ij")
dx = dx.ravel()
dy = dy.ravel()
dz = hist.ravel()
ax.bar3d(xpos, ypos, zpos, dx, dy, dz, zsort='average')

ax.view_init(60, 300)
ax.set_xlabel('Surge elevation (m)', fontsize=14)
ax.set_ylabel('Significant wave height (m)', fontsize=14)
ax.set_zlabel('Frequency', fontsize=14)
plt.tight_layout()
plt.show()





# Fragility table read from 'Pr.csv'; values are presumably probabilities on the
# (Zc, Hs) grid defined below — TODO confirm against the data source.
# NOTE(review): read_csv uses the first CSV line as the header; if Pr.csv has no
# header row, its first data row is consumed here — confirm.
fragility_curve = pd.read_csv('Pr.csv')

# 0.3048 = metres per foot; the grid axes below are presumably specified in feet.
FT_TO_M = 0.3048

# Grid axes: Zc from -10 to 9.5 in steps of 0.5, and Hs from 4 to 13.75 in steps
# of 0.25 (before conversion) — 40 values each.
Zc = np.arange(-10, 10, 0.50) * FT_TO_M
Hs = np.arange(4, 14, 0.25) * FT_TO_M

from scipy.interpolate import griddata
data = fragility_curve.to_numpy()

# meshgrid(Zc, Hs): rows follow Hs, columns follow Zc.
X1_grid, X2_grid = np.meshgrid(Zc, Hs)

# Scattered-point form for interpolation: one (Zc, Hs) pair per row in row-major
# grid order, paired with the table flattened in the same row-major order.
X1_X2_reshape = np.column_stack((X1_grid.ravel(), X2_grid.ravel()))
data_reshape = data.ravel()

# Cached (interpolator, lower_bounds, upper_bounds), built on first use by
# _fragility_interpolator().
_FRAGILITY_INTERP = None


def _fragility_interpolator():
    """Return the cached cubic interpolator over the fragility grid and its bounds.

    The interpolator is built once from the module-level ``X1_X2_reshape`` and
    ``data_reshape``. ``griddata(method='cubic')`` builds this same
    Clough-Tocher interpolant (``fill_value=nan``, ``rescale=False``) but
    re-triangulates every grid point on each call. Later reassignment of those
    globals is not picked up once the cache exists.
    """
    global _FRAGILITY_INTERP
    if _FRAGILITY_INTERP is None:
        from scipy.interpolate import CloughTocher2DInterpolator
        interp = CloughTocher2DInterpolator(X1_X2_reshape, data_reshape)
        _FRAGILITY_INTERP = (interp,
                             X1_X2_reshape.min(axis=0),
                             X1_X2_reshape.max(axis=0))
    return _FRAGILITY_INTERP


def interpolatedata(x1, x2):
    """Cubically interpolate the fragility table at (x1, x2).

    Args:
        x1: Coordinate(s) on the first grid axis (Zc), scalar or array-like.
        x2: Coordinate(s) on the second grid axis (Hs), same length as ``x1``.

    Coordinates outside the grid are clamped to its edges, so the result is
    always evaluated inside the table.

    Returns:
        The interpolated value for a single point, as before. For multiple
        points, an array with one value per point (previously only the first
        value was returned).
    """
    interp, lower, upper = _fragility_interpolator()
    # Bounds come from the grid itself: identical to the former literal limits
    # (-10..9.5 and 4..13.75, times 0.3048).
    x1 = np.clip(x1, lower[0], upper[0])
    x2 = np.clip(x2, lower[1], upper[1])
    points = np.column_stack((np.ravel(x1), np.ravel(x2)))
    value = interp(points)
    return value[0] if value.size == 1 else value

# For each clearance value, sum the interpolated fragility value (divided by
# 10000 — presumably a unit/normalisation factor, TODO confirm) over every
# prediction sample. One total per clearance is appended to sum_pro_list.
surge_row = surge_prediction[0]
wave_row = wave_prediction[0]

sum_pro_list = []
for clearance in [2, 2.5, 3]:
    sum_pro = 0
    for i, surge in enumerate(surge_row):
        relative_surge = clearance - surge
        wave_height = wave_row[i]
        pro = interpolatedata(relative_surge, wave_height) / 10000
        sum_pro += pro
        if not i % 100:
            # Progress report every 100 samples.
            print(clearance)
            print(i)
    sum_pro_list.append(sum_pro)


